//
// Copyright 2022 The Project Oak Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

#![no_std]
#![feature(associated_type_defaults)]
#![feature(error_in_core)]
#![feature(never_type)]
#![feature(try_blocks)]

extern crate alloc;

/// Protobuf types for the RPC envelope (`RequestWrapper`, `ResponseWrapper`,
/// `response_wrapper`, `Status`), included from `micro_rpc.rs` in the build
/// output directory (`OUT_DIR`).
// NOTE(review): presumably generated by this crate's build script (prost) from
// the micro_rpc `.proto` definitions — confirm against build.rs.
mod proto {
    // Generated code may define items this crate never uses; silence the
    // resulting dead-code warnings for the whole module.
    #![allow(dead_code)]
    include!(concat!(env!("OUT_DIR"), "/micro_rpc.rs"));
}

mod status;

use alloc::boxed::Box;
pub use alloc::{format, vec::Vec};
pub use core::result::{
    Result,
    Result::{Err, Ok},
};

use prost::Message;
pub use proto::{response_wrapper, RequestWrapper, ResponseWrapper};
pub use status::{Status, StatusCode};

/// A message-oriented transport that allows performing invocations.
///
/// Each invocation consists of atomically sending opaque bytes and receiving
/// opaque bytes. [`client_invoke`] uses this trait to carry encoded request
/// wrappers to a server and bring encoded response wrappers back.
pub trait Transport {
    /// Type representing any transport-specific errors. By default, the
    /// transport is assumed to be infallible: the default is the never type
    /// `!`, so an `Err` of this type can never be constructed.
    type Error = !;
    /// Sends `request_bytes` over the transport and returns the opaque bytes
    /// received in response, or a transport-specific [`Self::Error`].
    fn invoke(&mut self, request_bytes: &[u8]) -> Result<Vec<u8>, Self::Error>;
}

/// Same as [`Transport`], but for async use cases.
///
/// The trait is declared through `#[async_trait::async_trait]`, so
/// implementations must also annotate their `impl` blocks with that attribute.
/// [`async_client_invoke`] uses this trait to perform invocations.
#[async_trait::async_trait]
pub trait AsyncTransport {
    /// See [`Transport::Error`].
    type Error = !;
    /// See [`Transport::invoke`].
    async fn invoke(&mut self, request_bytes: &[u8]) -> Result<Vec<u8>, Self::Error>;
}

impl From<Status> for proto::Status {
    /// Lowers a [`Status`] into its wire representation: the status code is
    /// stored as its numeric `i32` value, and the message is moved over as-is.
    fn from(status: Status) -> Self {
        let wire_code = status.code as i32;
        Self { code: wire_code, message: status.message }
    }
}

impl From<proto::Status> for Status {
    /// Lifts a wire-level `proto::Status` back into a [`Status`].
    ///
    /// The signed wire code is reinterpreted as `u32` before conversion, so a
    /// negative code wraps to a large value. How such values map to a status
    /// code is decided by the `u32` conversion defined in the `status` module.
    fn from(wire: proto::Status) -> Self {
        let proto::Status { code, message } = wire;
        let unsigned_code = code as u32;
        Status::new_with_message(unsigned_code.into(), message)
    }
}

impl From<proto::ResponseWrapper> for Result<Vec<u8>, Status> {
    fn from(value: proto::ResponseWrapper) -> Self {
        match value.response {
            None => Err(Status::new_with_message(
                StatusCode::InvalidArgument,
                "invalid response wrapper",
            )),
            Some(proto::response_wrapper::Response::Error(error)) => Err(error.into()),
            Some(proto::response_wrapper::Response::Body(body)) => Ok(body),
        }
    }
}

impl From<Result<Vec<u8>, Status>> for proto::ResponseWrapper {
    fn from(value: Result<Vec<u8>, Status>) -> Self {
        match value {
            Ok(body) => proto::ResponseWrapper {
                response: Some(proto::response_wrapper::Response::Body(body)),
            },
            Err(error) => proto::ResponseWrapper {
                response: Some(proto::response_wrapper::Response::Error(error.into())),
            },
        }
    }
}

/// Performs an invocation over `transport`, calling the method identified by
/// `method_id` with `request` as its argument.
///
/// The request is protobuf-encoded and placed in a [`RequestWrapper`] together
/// with `method_id`. The encoded wrapper is then handed to
/// [`Transport::invoke`]. The bytes that come back are decoded as a
/// [`ResponseWrapper`], and on success its body is decoded as `Res`.
///
/// # Return value
///
/// The result is nested in two [`Result`] layers:
///
/// - The outer layer carries failures of the underlying transport
///   (`T::Error`). If the transport is infallible (its error type is `!`), this
///   layer can never be `Err`, so callers can safely `unwrap` it. Otherwise,
///   callers may want to handle transport errors differently from invocation
///   errors.
/// - The inner layer carries errors related to the invocation itself:
///   - errors usually generated at the application level on the server side;
///   - client-side decoding failures, reported as [`StatusCode::Internal`];
///   - a response wrapper with neither body nor error, reported as
///     [`StatusCode::InvalidArgument`].
///
/// This function is intended to be used by code generated by the
/// `micro_rpc_build` crate.
pub fn client_invoke<T: Transport, Req: prost::Message, Res: prost::Message + Default>(
    transport: &mut T,
    method_id: u32,
    request: &Req,
) -> Result<Result<Res, Status>, T::Error> {
    let encoded_request =
        RequestWrapper { method_id, body: request.encode_to_vec() }.encode_to_vec();

    // A transport failure returns early through the outer `Result` layer.
    let encoded_response = transport.invoke(&encoded_request)?;

    let outcome = ResponseWrapper::decode(encoded_response.as_slice())
        .map_err(|err| {
            Status::new_with_message(
                StatusCode::Internal,
                format!("Client failed to deserialize response wrapper: {}", err),
            )
        })
        .and_then(|wrapper| -> Result<Vec<u8>, Status> { wrapper.into() })
        .and_then(|body| {
            Res::decode(body.as_slice()).map_err(|err| {
                Status::new_with_message(
                    StatusCode::Internal,
                    format!("Client failed to deserialize response body: {}", err),
                )
            })
        });

    Ok(outcome)
}

/// Same as [`client_invoke`], but via an [`AsyncTransport`].
///
/// The request is wrapped and encoded, and the response is decoded, exactly as
/// in [`client_invoke`], so both functions report errors identically:
///
/// - transport failures are returned in the outer [`Result`] layer;
/// - a response wrapper that cannot be decoded, or whose body cannot be decoded
///   as `Res`, yields a [`StatusCode::Internal`] status;
/// - an error carried by the response wrapper is converted to a [`Status`] and
///   returned in the inner layer;
/// - a response wrapper with neither body nor error yields
///   [`StatusCode::InvalidArgument`] ("invalid response wrapper"), via the shared
///   `ResponseWrapper` -> `Result` conversion.
pub async fn async_client_invoke<
    T: AsyncTransport,
    Req: prost::Message,
    Res: prost::Message + Default,
>(
    transport: &mut T,
    method_id: u32,
    request: &Req,
) -> Result<Result<Res, Status>, T::Error> {
    let request_body = request.encode_to_vec();
    let request_wrapper = RequestWrapper { method_id, body: request_body };
    let request_wrapper_bytes = request_wrapper.encode_to_vec();
    // This may result in transport errors, corresponding to the outer Result layer.
    let response_wrapper_bytes = transport.invoke(&request_wrapper_bytes).await?;
    let result: Result<Res, Status> = try {
        let response_wrapper =
            ResponseWrapper::decode(response_wrapper_bytes.as_slice()).map_err(|err| {
                Status::new_with_message(
                    StatusCode::Internal,
                    format!("Client failed to deserialize response wrapper: {}", err),
                )
            })?;
        // Use the shared `ResponseWrapper` -> `Result` conversion (the same one
        // `client_invoke` relies on) so the sync and async clients interpret
        // responses, including empty wrappers, identically.
        let response_result: Result<Vec<u8>, Status> = response_wrapper.into();
        let body = response_result?;
        Res::decode(body.as_slice()).map_err(|err| {
            Status::new_with_message(
                StatusCode::Internal,
                format!("Client failed to deserialize response body: {}", err),
            )
        })?
    };
    Ok(result)
}
